#include <UnitTest++.h>
#include "config.h"
#include "mpi_master_impl.h"
#include <tbb/concurrent_queue.h>
#include "xfer.h"

using namespace mcrx;

/** These are tests for the mpi_master class. For them to make any
    sense, this test program has to be launched with mpirun so that
    there is more than one task to communicate with.
*/

/// Dummy grid used by test_xfer. Location requests and responses are
/// plain chars, and every request is answered with a zero response.
class test_grid {
public:
  typedef char T_location_request;
  typedef char T_location_response;

  T_location_response location_request(T_location_request& /*request*/)
  {
    const T_location_response reply = 0;
    return reply;
  }
};

class test_xfer {
public:
  enum { Send_wait };
  typedef test_grid T_grid;
  typedef ray<array_1> T_ray;
  //typedef ray<blitz::TinyVector<double, 10> > T_ray;
  typedef queue_item<T_ray, hilbert::cell_code> T_queue_item;
  typedef tbb::concurrent_bounded_queue<T_queue_item> T_ray_queue;

  typedef std::pair<int,blitz::TinyVector<T_float, 2> > T_pixel_queue_item;
  typedef tbb::concurrent_queue<T_pixel_queue_item> T_pixel_queue;

  mutable boost::shared_ptr<T_ray_queue> local_pending_;
  mutable boost::shared_ptr<T_pixel_queue> px_queue_;
  T_grid g_;
  const static int len=120;

  test_xfer() : 
    local_pending_(new T_ray_queue()), 
    px_queue_(new T_pixel_queue()), g_() {};

  boost::shared_ptr<T_ray_queue> local_queue() const {return local_pending_; };
  boost::shared_ptr<T_pixel_queue> pixel_queue() const {return px_queue_; };
  T_grid& grid() { return g_; };
  void set_master(const mpi_master<test_xfer>*) const {};
  void init_queue_item(T_queue_item& i) const {
    T_ray r; 
    T_densities col(1); r.set_column_ref(col);
    T_ray::T_lambda intens(len),norm(len);
    intens=2*blitz::tensor::i;
    norm=4*blitz::tensor::i;
    r.set_intensity_ref(intens);
    i= T_queue_item(T_queue_item::propagate, r,  hilbert::cell_code(42,2),
		    norm, 4711);
  }
};

/** Checks that an mpi_master can be constructed and destroyed.

    This test also calls mpienv::init and hpm::disable. The later tests
    in this file use MPI without initialising it themselves, so they
    presumably rely on UnitTest++ running this test first (definition
    order within the file). */
TEST(construction)
{
  mpienv::init(0,0);
  mpi_barrier();
  hpm::disable();
  test_xfer x;
  // tbb::atomic's default constructor leaves the value unspecified, so
  // zero the counter explicitly (as the other tests do) before handing
  // it to the master.
  tbb::atomic<long> nr;
  nr = 0;
  mpi_master<test_xfer> m(x, 0, nr, 0);
  mpi_barrier();
}

/** Constructs a master with a zeroed counter and zero for its other
    numeric arguments, then runs it. run() is expected to return right
    away in that configuration. */
TEST(run)
{
  mpi_barrier();

  test_xfer x;
  tbb::atomic<long> nr;
  nr = 0;
  mpi_master<test_xfer> m(x, 0, nr, 0);

  m.run();   // this should exit immediately

  mpi_barrier();
}

/** Fixture for spawn_threads. It owns the xfer, the shared counter nr,
    the per-worker ray count nrd (1000) and the mpi_master under test,
    which is built from the three preceding members. */
class mpimaster_fixture {
public:
  mpimaster_fixture()
    : x()
    , nr()
    , nrd(1000)
    , m(x, 1, nr, nrd)
  {
    nr = 0;
  }

  // Declaration order matters: m is initialised from x, nr and nrd, so
  // it must stay the last member.
  test_xfer x;
  tbb::atomic<long> nr;
  long nrd;
  mpi_master<test_xfer> m;
};

/** Worker functor run in a boost thread by spawn_threads. It sends rays
    through the master and then checks what arrives on the local queue.
    It stores references only, so everything it is constructed from must
    outlive it. */
class test_thread {
public:
  typedef test_xfer T_xfer;
  typedef T_xfer::T_queue_item T_queue_item;
  typedef T_xfer::T_ray T_ray;

  // References rather than copies: the worker needs the very same
  // queues, master and counter that the test body uses.
  test_xfer& x_;
  mpi_master<test_xfer>& m_;
  tbb::atomic<long>& nr_;
  long nrd_;

  test_thread(test_xfer& x, mpi_master<test_xfer>& m,
              tbb::atomic<long>& nr, long nrd)
    : x_(x)
    , m_(m)
    , nr_(nr)
    , nrd_(nrd)
  {}

  void operator()();
};

/** Worker body.

    Builds one propagate item for cell (42,2) and sends it nrd_ times
    through the master to the next rank in a ring (the last rank sends
    to rank 0), incrementing nr_ once per send. It then pops items from
    the local queue, checking each item's cell and norm, until an item
    with action invalid arrives. Presumably the master pushes that item
    on shutdown; it is not pushed anywhere in this file. */
void test_thread::operator()()
{
  T_ray r;
  r.set_position(vec3d(1,2,3));
  r.set_direction(vec3d(0,1,0));
  T_densities col(1);
  r.set_column_ref(col);
  T_ray::T_lambda intens(x_.len),norm(x_.len);
  intens=2*blitz::tensor::i;
  norm=4*blitz::tensor::i;
  r.set_intensity_ref(intens);

  T_queue_item q(T_queue_item::propagate, r, hilbert::cell_code(42,2),
                 norm, 4711);

  // Next rank in a ring; identical to "rank+1, wrapping to 0".
  const int dest = (mpi_rank() + 1) % mpi_size();

  // Counter has the same type as nrd_ to avoid an int/long comparison.
  for(long i=0; i<nrd_; ++i) {
    m_.thread_send_ray(q, 0, dest);
    ++nr_;
  }

  // Distinct name so the sent item q is not shadowed.
  while(true) {
    T_queue_item received;
    x_.local_queue()->pop(received);   // blocking pop
    if(received.action_==T_queue_item::invalid)
      break;
    CHECK_EQUAL(hilbert::cell_code(42,2), received.cell_);
    CHECK_ARRAY_EQUAL(norm.dataFirst(), received.norm_.dataFirst(),
                      norm.size());
  }

  printf("worker %d exited\n",mpi_rank());
}

/** Runs the master on this thread while one worker thread (test_thread)
    feeds it rays. After both finish, nr must equal nrd*size on rank 0
    and 0 on every other rank. */
TEST_FIXTURE(mpimaster_fixture, spawn_threads)
{
  mpi_barrier();

  // boost::ref so the thread runs this very functor rather than a copy;
  // it is joined below, before it goes out of scope.
  test_thread worker(x, m, nr, nrd);
  boost::thread_group threads;
  threads.create_thread(boost::ref(worker));

  m.run();
  printf("master %d exited\n", mpi_rank());
  threads.join_all();

  CHECK_EQUAL(mpi_rank() == 0 ? nrd * mpi_size() : 0, nr);
  mpi_barrier();
}
